# 作者 :南雨
# 时间 : 2022/6/23 9:26
import pandas as pd
from dzj.trec_qa.trec_constant import Trec_constant
from sklearn.utils import shuffle


def category_to_id():
    """
    Build the name-to-integer lookup tables for labels and sub-labels.

    Each name is mapped to its position in ``Trec_constant.labels`` or
    ``Trec_constant.sub_labels`` respectively.

    :return: tuple ``(label_to_id, sub_labels_to_id)`` of two dicts,
        e.g. ``{"manner": 0, "sport": 20}``
    """
    label_names = Trec_constant.labels
    sub_label_names = Trec_constant.sub_labels
    label_ids = {name: index for index, name in enumerate(label_names)}
    sub_label_ids = {name: index for index, name in enumerate(sub_label_names)}
    return label_ids, sub_label_ids


def load_data(path):
    """
    Load a labelled text file and turn every line into ids plus raw text.

    Each non-blank line is expected to have the form
    ``<label>:<sub_label> <text>``. The label and sub-label are converted to
    the integer ids produced by ``category_to_id``; the text is kept as-is
    (with surrounding whitespace stripped). Blank lines are skipped.

    :param path: path of the data file; it is decoded as GBK
    :return: list of ``[label_id, sub_label_id, text]`` rows
        (a plain list, not a DataFrame)
    :raises KeyError: if a label or sub-label is missing from the mappings
    :raises IndexError: if a line has no ":" separator, or no text after
        the sub-label
    """
    # The mappings never change while reading, so build them once here
    # instead of rebuilding both dicts twice for every line.
    label_to_id, sub_label_to_id = category_to_id()

    total_list = []
    with open(path, encoding="GBK") as f:
        for line in f:
            text = line.strip()
            if not text:
                # Blank lines carry no sample; skip them rather than failing
                # with an unrelated KeyError on the empty string.
                continue
            # Split only on the first ":" so that colons (or any other
            # character sequence) inside the question text stay intact.
            parts = text.split(":", 1)
            label_id = label_to_id[parts[0]]
            # The sub-label ends at the first space; the rest is the text.
            sub_parts = parts[1].split(" ", 1)
            total_list.append(
                [label_id, sub_label_to_id[sub_parts[0]], sub_parts[1]]
            )
    return total_list


def get_data():
    """
    Load the train and test files and expose them as DataFrames and raw rows.

    The files are read from ``Trec_constant.train_path`` and
    ``Trec_constant.test_path`` via ``load_data``.

    :return: tuple ``(train_df, test_df, all_list)`` where the two DataFrames
        have the columns ``labels``, ``sub_labels`` and ``text``, and
        ``all_list`` is the train rows followed by the test rows
    """
    column_names = ["labels", "sub_labels", "text"]
    train_rows = load_data(Trec_constant.train_path)
    test_rows = load_data(Trec_constant.test_path)
    train_df = pd.DataFrame(train_rows, columns=column_names)
    test_df = pd.DataFrame(test_rows, columns=column_names)
    return train_df, test_df, train_rows + test_rows


def get_all_text_list(all_list):
    """
    Extract the text field of every row, each wrapped in its own list.

    :param all_list: sequence of rows whose element at index 2 is the text
    :return: list of one-element lists, ``[[text], [text], ...]``, in the
        same order as ``all_list``
    """
    return [[row[2]] for row in all_list]


def get_all():
    """
    Return the shuffled train+test DataFrame together with all texts.

    :return: tuple ``(data_df, text_list)``; ``data_df`` is the train and test
        DataFrames concatenated row-wise, shuffled and re-indexed from 0,
        and ``text_list`` holds every text as a one-element list in the
        original (unshuffled) train-then-test order
    """
    train_df, test_df, all_list = get_data()
    combined_df = pd.concat([train_df, test_df], axis=0)
    # Drop the old index so the shuffled frame is numbered 0..n-1 again.
    shuffled_df = shuffle(combined_df).reset_index(drop=True)
    return shuffled_df, get_all_text_list(all_list)
